
Lance/nnx all #3425

Draft

ecnal-cienet wants to merge 16 commits into main from lance/nnx_all

Conversation

@ecnal-cienet (Collaborator)

ecnal-cienet commented Mar 16, 2026

Description

Start with a short description of what the PR does and how this is a change from
the past.

The rest of the description includes relevant details and context, examples:

  • why is this change being made,
  • the problem being solved and any relevant context,
  • why this is a good solution,
  • some information about the specific implementation,
  • shortcomings of the solution and possible future improvements.

If the change fixes a bug or a GitHub issue, please include a link, e.g.:
FIXES: b/123456
FIXES: #123456

Notice 1: Once all tests pass, the "pull ready" label will automatically be assigned.
This label is used for administrative purposes. Please do not add it manually.

Notice 2: For external contributions, our settings currently require an approval from a MaxText maintainer to trigger CI tests.

Tests

Please describe how you tested this change, and include any instructions and/or
commands to reproduce.

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.

ecnal-cienet force-pushed the lance/nnx_all branch 2 times, most recently from 08d190d to 3ed2aad (March 16, 2026 21:24)
Charles Li and others added 16 commits March 19, 2026 14:56
- pure_nnx: a flag to choose pure NNX logic when NNX and Linen models
  co-exist.
- init_state_fn: a function to initialize the model state for
  training. It will be set to a different function for NNX and Linen.
- Add utils to manipulate the NNX shardings with abstract state of a
  model
  - also add unit tests for the utils
- Extract mesh creation function to maxtext_utils.get_mesh_from_config()
  - also add unit tests for this func
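The flag-based dispatch described above can be sketched as follows. This is a minimal illustration, not the actual MaxText API: the function names and the state layouts are hypothetical stand-ins.

```python
# Hypothetical sketch of selecting init_state_fn via the pure_nnx flag
# when NNX and Linen models co-exist. All names and state layouts here
# are illustrative, not the real MaxText implementation.

def init_state_linen(rng_seed):
    """Placeholder Linen-style init: returns a params-only state dict."""
    return {"params": {"dense": {"kernel": [0.0] * 4}}, "framework": "linen"}

def init_state_nnx(rng_seed):
    """Placeholder NNX-style init: returns a model/optimizer state dict."""
    return {"model": {"dense": {"kernel": [0.0] * 4}}, "framework": "nnx"}

def select_init_state_fn(pure_nnx: bool):
    # The flag picks which initializer the training loop will call.
    return init_state_nnx if pure_nnx else init_state_linen

state = select_init_state_fn(pure_nnx=True)(rng_seed=0)
print(state["framework"])  # -> nnx
```

Keeping the selection in one place lets the rest of the training loop stay agnostic to which framework produced the state.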

Note:
flax v0.12 has DeprecationWarning in multiple places:
  - DeprecationWarning: '.value' access is now deprecated. Use
    variable.get_value() or variable[...] (for [Array]).
  - DeprecationWarning: 'VariableState' was removed, this is just
    an alias to 'Variable'. Please use 'Variable' directly instead.
But since the code needs to work with post-training, which currently
requires flax v0.11, we didn't change code for these warnings.
A TrainState for NNX, which includes model and optimizer
Unit tests include checkpoint tests:
- restore a saved state
- convert linen TrainState to NNX TrainState
- Parameter only restore (no opt_state)
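The parameter-only restore case from the tests above can be illustrated with a small sketch. The train-state layout shown here is an assumption for illustration only, not the real checkpoint schema:

```python
# Illustrative parameter-only restore: drop opt_state from a saved
# train-state dict before loading. The keys ("params", "opt_state",
# "step") are hypothetical stand-ins for the real checkpoint layout.
def params_only(saved_state: dict) -> dict:
    return {k: v for k, v in saved_state.items() if k != "opt_state"}

saved = {"params": {"w": [1, 2]}, "opt_state": {"mu": [0, 0]}, "step": 7}
restored = params_only(saved)
print(sorted(restored))  # -> ['params', 'step']
```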
…ion_utils

Also added unit tests.
Refactored model_creation_utils to provide common create_nnx_abstract_model() func.

b/src/maxtext/utils/model_creation_utils.py
1. A new func get_abstract_state_nnx() is added to maxtext_utils.py.
   It will be called during training to create the NNX training state.

Like the Linen version, it handles shard_optimizer_over_data,
optimizer_memory_host_offload, and parameter_memory_host_offload.

Unit tests are added to this NNX func.

2. Add nnx train_state handling in train_utils.py

DPO handling will be supported (or removed) later in train_utils.py
Also added unit tests for NNX model.
- get_functional_train_with_signature: use (state, batch) shardings when pure_nnx=True
- get_functional_eval_with_signature: use (state, batch) shardings when pure_nnx=True
- Convert nnx.State to pure dict for checkpoint saving
- Restore pure dict back to nnx.State after loading
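The save/restore round-trip above can be sketched as follows. A real implementation would walk nnx.State and nnx.Variable objects; here plain Python classes stand in, so the structure (but not the API) is illustrative:

```python
# Minimal sketch (not the MaxText implementation) of round-tripping a
# nested state tree to a plain dict for checkpoint saving and back.
# Var stands in for a variable wrapper such as nnx.Variable.

class Var:
    def __init__(self, value):
        self.value = value

def to_pure_dict(tree):
    """Unwrap every leaf so the tree is checkpoint-friendly plain data."""
    if isinstance(tree, Var):
        return tree.value
    return {k: to_pure_dict(v) for k, v in tree.items()}

def from_pure_dict(tree):
    """Re-wrap every leaf after loading the plain dict from disk."""
    if isinstance(tree, dict):
        return {k: from_pure_dict(v) for k, v in tree.items()}
    return Var(tree)

state = {"layer": {"kernel": Var([1.0, 2.0]), "bias": Var([0.0])}}
pure = to_pure_dict(state)
print(pure)  # -> {'layer': {'kernel': [1.0, 2.0], 'bias': [0.0]}}
```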
Add a bidirectional Linen <-> NNX checkpoint converter tool that handles:
  - Auto-detection of checkpoint format
  - Conversion of params structure (double nesting vs flat)
  - Stacking/unstacking per-layer parameters
  - Value wrapper handling for NNX format
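The format auto-detection bullet above could work roughly like this. The "double nesting vs flat" layouts are taken from the commit message, but the exact key structure shown is an assumption:

```python
# Hypothetical auto-detection of checkpoint format by params nesting:
# a double-nested {"params": {"params": ...}} layout is treated as
# Linen, a flat {"params": ...} layout as NNX. The layouts are
# illustrative assumptions, not the verified on-disk schema.
def detect_format(ckpt: dict) -> str:
    params = ckpt.get("params", {})
    return "linen" if "params" in params else "nnx"

print(detect_format({"params": {"params": {"w": 1}}}))  # -> linen
print(detect_format({"params": {"w": 1}}))              # -> nnx
```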
  Add a tool to compare checkpoint tree structures, shapes, and values
  across Linen and NNX formats. Supports cross-format and same-format
  comparisons with auto-detection, layer axis transposition, and RNG
  filtering.
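The stacking/unstacking of per-layer parameters mentioned above can be sketched with plain lists standing in for arrays. The "layers_N" naming is an illustrative assumption:

```python
# Sketch of stacking/unstacking per-layer parameters, as a
# Linen <-> NNX converter might do when one format stores layers
# individually ("layers_0", "layers_1", ...) and the other as one
# stacked sequence per parameter. Names and layout are hypothetical.
def stack_layers(per_layer: dict) -> list:
    # Order by the numeric suffix so "layers_10" sorts after "layers_2".
    keys = sorted(per_layer, key=lambda k: int(k.rsplit("_", 1)[1]))
    return [per_layer[k] for k in keys]

def unstack_layers(stacked: list, prefix: str = "layers") -> dict:
    return {f"{prefix}_{i}": v for i, v in enumerate(stacked)}

per_layer = {"layers_0": [1.0], "layers_1": [2.0]}
stacked = stack_layers(per_layer)
print(stacked)                  # -> [[1.0], [2.0]]
print(unstack_layers(stacked))  # -> {'layers_0': [1.0], 'layers_1': [2.0]}
```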
- Add --pure_nnx CLI flag to run_sharding_dump.py
- Propagate pure_nnx=true to the sharding_dump subprocess when flag is set
- Refactor run_single_dump() to build the command as a list for conditional flag appending
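Building the command as a list, as the refactor above describes, makes conditional flag appending straightforward. The script name and flag value come from this PR's text, but the exact helper shown is a hypothetical sketch:

```python
# Minimal sketch of building a subprocess command as a list so flags
# can be appended conditionally, in the spirit of the run_single_dump()
# refactor. The helper name and argument format are assumptions.
def build_dump_command(config_path: str, pure_nnx: bool) -> list:
    cmd = ["python", "run_sharding_dump.py", config_path]
    if pure_nnx:
        # Propagate the flag to the sharding_dump subprocess only when set.
        cmd.append("pure_nnx=true")
    return cmd

print(build_dump_command("base.yml", pure_nnx=True))
# -> ['python', 'run_sharding_dump.py', 'base.yml', 'pure_nnx=true']
```

A list (rather than a concatenated shell string) also avoids quoting issues when it is eventually passed to subprocess.run.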
- Replace nn.Dropout with linears.Dropout in gpt_oss and olmo3 decoder layers
- Add num_activations logical axis rule to base.yml
- Fix integration and unit tests for NNX compatibility

I will relocate these files accordingly once the work is done.
@github-actions

This PR has been automatically marked as stale because it has not had recent activity. It will be closed soon if no further activity occurs. Thank you for your contributions.

@github-actions github-actions Bot added the stale Automatically applied to stale PRs. label Apr 19, 2026

2 participants